{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "/home/huangweilin/anaconda3/envs/fjw/lib/python3.6/site-packages/dgl/base.py:25: UserWarning: Detected an old version of PyTorch. Suggest using torch>=1.2.0 for the best experience.\n",
      "  warnings.warn(msg, warn_type)\n"
     ]
    }
   ],
   "source": [
    "import dgl\n",
    "import torch\n",
    "from transformers import *\n",
    "from transformers import AdamW\n",
    "import torch.utils.data as Data\n",
    "import collections\n",
    "import os\n",
    "import random\n",
    "import tarfile\n",
    "import torch\n",
    "from torch import nn\n",
    "import torchtext.vocab as Vocab\n",
    "import pickle as pk\n",
    "import copy\n",
    "import time\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import torch.nn.functional as F\n",
    "from IPython.display import display,HTML\n",
    "import os\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.nn.utils.rnn import pack_padded_sequence\n",
    "from torch.nn.utils.rnn import pad_packed_sequence\n",
    "from torch.nn.utils.rnn import pack_sequence\n",
    "from torch.nn import CrossEntropyLoss, MSELoss\n",
    "import math\n",
    "device=torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n",
    "import argparse\n",
    "import glob\n",
    "import json\n",
    "import logging\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "import torch.utils.data as Data\n",
    "import dgl\n",
    "import dgl.function as fn\n",
    "import torch as th\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from dgl import DGLGraph\n",
    "\n",
    "# DGL built-in message function: copies each source node's 'h' feature into message field 'm'.\n",
    "gcn_msg = fn.copy_src(src='h', out='m')\n",
    "# DGL built-in reduce function: sums the incoming 'm' messages into the receiving node's 'h'.\n",
    "gcn_reduce = fn.sum(msg='m', out='h')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## GCN GAT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "code_folding": [
     0,
     14,
     50,
     66
    ],
    "hidden": true
   },
   "outputs": [],
   "source": [
    "class GCNLayer(nn.Module):\n",
    "    \"\"\"Graph convolution layer: sum neighbour features, then apply a linear map.\"\"\"\n",
    "\n",
    "    def __init__(self, in_feats, out_feats):\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(in_feats, out_feats)\n",
    "\n",
    "    def forward(self, g, feature):\n",
    "        \"\"\"Aggregate `feature` over the edges of `g` and project the result.\n",
    "\n",
    "        Args:\n",
    "            g: graph object providing `local_scope`, `ndata` and `update_all`.\n",
    "            feature: per-node input tensor, stored as node data 'h'.\n",
    "\n",
    "        Returns:\n",
    "            The output of `self.linear` applied to the aggregated node data.\n",
    "        \"\"\"\n",
    "        # Inside local_scope() the temporary 'h' node data is dropped again\n",
    "        # when the scope exits, so nothing is left behind on the graph.\n",
    "        with g.local_scope():\n",
    "            g.ndata['h'] = feature\n",
    "            g.update_all(gcn_msg, gcn_reduce)\n",
    "            aggregated = g.ndata['h']\n",
    "        return self.linear(aggregated)\n",
    "class GATLayer(nn.Module):\n",
    "    def __init__(self, g, in_dim, out_dim):\n",
    "        super(GATLayer, self).__init__()\n",
    "        self.g = g\n",
    "        # equation (1)\n",
    "        self.fc = nn.Linear(in_dim, out_dim, bias=False)\n",
    "        # equation (2)\n",
    "        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)\n",
    "\n",
    "    def edge_attention(self, edges):\n",
    "        # edge UDF for equation (2)\n",
    "        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)\n",
    "        a = self.attn_fc(z2)\n",
    "        return {'e': F.leaky_relu(a)}\n",
    "\n",
    "    def message_func(self, edges):\n",
    "        # message UDF for equation (3) & (4)\n",
    "        return {'z': edges.src['z'], 'e': edges.data['e']}\n",
    "\n",
    "    def reduce_func(self, nodes):\n",
    "        # reduce UDF for equation (3) & (4)\n",
    "        # equation (3)\n",
    "        alpha = F.softmax(nodes.mailbox['e'], dim=1)\n",
    "        # equation (4)\n",
    "        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)\n",
    "        return {'h': h}\n",
    "\n",
    "    def forward(self, h):\n",
    "        # equation (1)\n",
    "        z = self.fc(h)\n",
    "        self.g.ndata['z'] = z\n",
    "        # equation (2)\n",
    "        self.g.apply_edges(self.edge_attention)\n",
    "        # equation (3) & (4)\n",
    "        self.g.update_all(self.message_func, self.reduce_func)\n",
    "        return self.g.ndata.pop('h')\n",
    "class MultiHeadGATLayer(nn.Module):\n",
    "    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):\n",
    "        super(MultiHeadGATLayer, self).__init__()\n",
    "        self.heads = nn.ModuleList()\n",
    "        for i in range(num_heads):\n",
    "            self.heads.append(GATLayer(g, in_dim, out_dim))\n",
    "        self.merge = merge\n",
    "\n",
    "    def forward(self, h):\n",
    "        head_outs = [attn_head(h) for attn_head in self.heads]\n",
    "        if self.merge == 'cat':\n",
    "            # concat on the output feature dimension (dim=1)\n",
    "            return torch.cat(head_outs, dim=1)\n",
    "        else:\n",
    "            # merge using average\n",
    "            return torch.mean(torch.stack(head_outs))\n",
    "class GAT(nn.Module):\n",
    "    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):\n",
    "        super(GAT, self).__init__()\n",
    "        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)\n",
    "        # Be aware that the input dimension is hidden_dim*num_heads since\n",
    "        # multiple head outputs are concatenated together. Also, only\n",
    "        # one attention head in the output layer.\n",
    "        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)\n",
    "\n",
    "    def forward(self, h):\n",
    "        h = self.layer1(h)\n",
    "        h = F.elu(h)\n",
    "        h = self.layer2(h)\n",
    "        return h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "code_folding": [
     0
    ],
    "hidden": true
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self,input_features,hidden_size,num_labels):\n",
    "        super(Net, self).__init__()\n",
    "        self.layer1 = GCNLayer(input_features,hidden_size)\n",
    "        self.layer2 = GCNLayer(hidden_size, num_labels)\n",
    "    \n",
    "    def forward(self, g, features):\n",
    "        x = F.relu(self.layer1(g, features))\n",
    "        x = self.layer2(g, x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "code_folding": [
     2
    ],
    "hidden": true
   },
   "outputs": [],
   "source": [
    "from dgl.data import citation_graph as citegrh\n",
    "import networkx as nx\n",
    "def load_cora_data():\n",
    "    \"\"\"Load the Cora dataset through `dgl.data.citation_graph`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(g, features, labels, train_mask, test_mask)`.\n",
    "    \"\"\"\n",
    "    cora = citegrh.load_cora()\n",
    "    tensors = (\n",
    "        th.FloatTensor(cora.features),\n",
    "        th.LongTensor(cora.labels),\n",
    "        th.BoolTensor(cora.train_mask),\n",
    "        th.BoolTensor(cora.test_mask),\n",
    "    )\n",
    "    graph = DGLGraph(cora.graph)\n",
    "    return (graph,) + tensors\n",
    "def build_graph(data_dir=\"./text_classify/data\", num_train=10000, feature_dim=768):\n",
    "    \"\"\"Build the reference graph plus node features, labels and split masks.\n",
    "\n",
    "    Args:\n",
    "        data_dir: directory holding text.csv, reference.csv, train.csv and\n",
    "            bert_hidden_and_probs.pk.\n",
    "        num_train: number of leading rows of train.csv used for training;\n",
    "            the remaining labelled rows form the validation split.\n",
    "        feature_dim: width of the per-node feature vectors.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(g, features, labels, train_mask, val_mask, test_mask)`.\n",
    "        Nodes without an entry in train.csv get label 0 and are the ones\n",
    "        selected by `test_mask`.\n",
    "    \"\"\"\n",
    "    num_of_nodes = int(pd.read_csv(os.path.join(data_dir, \"text.csv\"))[\"id\"].max()) + 1\n",
    "    g = dgl.DGLGraph()\n",
    "    g.add_nodes(num_of_nodes)\n",
    "    edge_list = pd.read_csv(os.path.join(data_dir, \"reference.csv\")).values\n",
    "\n",
    "    src, dst = tuple(zip(*edge_list))\n",
    "    g.add_edges(src, dst)\n",
    "\n",
    "    # Parse train.csv once and reuse the frame for both ids and labels.\n",
    "    train_df = pd.read_csv(os.path.join(data_dir, \"train.csv\"))\n",
    "    full_labeled_nodes = train_df['id'].values\n",
    "    full_labels = train_df['label'].values\n",
    "\n",
    "    # NOTE(review): pickle.load can execute arbitrary code; only load files\n",
    "    # produced by a trusted pipeline.\n",
    "    with open(os.path.join(data_dir, \"bert_hidden_and_probs.pk\"), \"rb\") as fh:\n",
    "        id2features = pk.load(fh)\n",
    "    features = np.zeros((len(id2features), feature_dim))\n",
    "    for key, value in id2features.items():\n",
    "        # Index 1 of each entry is used as the node's feature vector.\n",
    "        features[key] = value[1]\n",
    "    features = torch.tensor(features).float()\n",
    "\n",
    "    total_nodes = g.number_of_nodes()\n",
    "    labels = np.zeros(total_nodes)\n",
    "    labels[full_labeled_nodes] = full_labels\n",
    "    labels = torch.LongTensor(labels)  # their labels are different\n",
    "\n",
    "    train_labeled_nodes = full_labeled_nodes[:num_train]\n",
    "    val_labeled_nodes = full_labeled_nodes[num_train:]\n",
    "\n",
    "    def _mask(node_ids, selected):\n",
    "        # selected=True: only `node_ids` are set; selected=False: everything but `node_ids`.\n",
    "        idxs = np.zeros(total_nodes) if selected else np.ones(total_nodes)\n",
    "        idxs[node_ids] = 1 if selected else 0\n",
    "        return torch.BoolTensor(idxs)\n",
    "\n",
    "    train_mask = _mask(train_labeled_nodes, True)\n",
    "    val_mask = _mask(val_labeled_nodes, True)\n",
    "    test_mask = _mask(full_labeled_nodes, False)\n",
    "    return g, features, labels, train_mask, val_mask, test_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "def evaluate(model, g, features, labels, mask):\n",
    "    model.eval()\n",
    "    with th.no_grad():\n",
    "#         logits = model( features) #GAT\n",
    "        logits=model(g,features) #GCN\n",
    "        logits = logits[mask]\n",
    "        labels = labels[mask]\n",
    "        _, indices = th.max(logits, dim=1)\n",
    "        correct = th.sum(indices == labels)\n",
    "        return correct.item() * 1.0 / len(labels)\n",
    "def predict(model, g, features, labels, mask):\n",
    "    model.eval()\n",
    "    pred_indices=[]\n",
    "    with th.no_grad():\n",
    "#         logits = model( features) #GAT\n",
    "        logits=model(g,features) #GCN\n",
    "        logits = logits[mask]\n",
    "        labels = labels[mask]\n",
    "        _, indices = th.max(logits, dim=1)\n",
    "        pred_indices.append(indices.detach().cpu().numpy())\n",
    "        return  np.concatenate(pred_indices)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## 建图及读取特征"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We have 25561 nodes.\n",
      "We have 73313 edges.\n"
     ]
    }
   ],
   "source": [
    "# Build the graph together with node features, labels and the train/val/test masks.\n",
    "g, features, labels, train_mask, val_mask,test_mask=build_graph()\n",
    "# Report the size of the graph that was just built.\n",
    "print('We have %d nodes.' % g.number_of_nodes())\n",
    "print('We have %d edges.' % g.number_of_edges())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([25561, 768])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the shape of the node feature matrix returned by build_graph().\n",
    "features.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# GAT variant of the model. NOTE: GAT.forward takes only the features (net(features)),\n",
    "# while the training loop, evaluate() and predict() call net(g, features); the\n",
    "# following cell rebinds `net` to the GCN `Net`, which matches that call.\n",
    "net = GAT(g,\n",
    "          in_dim=features.size()[1],\n",
    "          hidden_dim=384,\n",
    "          out_dim=5,\n",
    "          num_heads=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# GCN model: 768 input features -> 384 hidden units -> 5 output classes.\n",
    "net=Net(768,384,5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (layer1): GCNLayer(\n",
       "    (linear): Linear(in_features=768, out_features=384, bias=True)\n",
       "  )\n",
       "  (layer2): GCNLayer(\n",
       "    (linear): Linear(in_features=384, out_features=5, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the module structure of the current model.\n",
    "net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "hidden": true,
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00000 | Loss 1.3490 | Test Acc 0.4192 | Time(s) nan\n",
      "Epoch 00001 | Loss 1.4897 | Test Acc 0.4509 | Time(s) nan\n",
      "Epoch 00002 | Loss 1.3383 | Test Acc 0.4520 | Time(s) nan\n",
      "Epoch 00003 | Loss 1.3357 | Test Acc 0.4473 | Time(s) 1.1836\n",
      "Epoch 00004 | Loss 1.3587 | Test Acc 0.4484 | Time(s) 1.1631\n",
      "Epoch 00005 | Loss 1.3597 | Test Acc 0.4462 | Time(s) 1.2757\n",
      "Epoch 00006 | Loss 1.3558 | Test Acc 0.4476 | Time(s) 1.2471\n",
      "Epoch 00007 | Loss 1.3509 | Test Acc 0.4491 | Time(s) 1.2354\n",
      "Epoch 00008 | Loss 1.3342 | Test Acc 0.4484 | Time(s) 1.2589\n",
      "Epoch 00009 | Loss 1.3096 | Test Acc 0.4480 | Time(s) 1.2637\n",
      "Epoch 00010 | Loss 1.2844 | Test Acc 0.4494 | Time(s) 1.2718\n",
      "Epoch 00011 | Loss 1.2675 | Test Acc 0.4502 | Time(s) 1.2680\n",
      "Epoch 00012 | Loss 1.2619 | Test Acc 0.4494 | Time(s) 1.3304\n",
      "Epoch 00013 | Loss 1.2559 | Test Acc 0.4476 | Time(s) 1.3641\n",
      "Epoch 00014 | Loss 1.2445 | Test Acc 0.4473 | Time(s) 1.3706\n",
      "Epoch 00015 | Loss 1.2337 | Test Acc 0.4476 | Time(s) 1.3608\n",
      "Epoch 00016 | Loss 1.2283 | Test Acc 0.4466 | Time(s) 1.3485\n",
      "Epoch 00017 | Loss 1.2302 | Test Acc 0.4476 | Time(s) 1.3321\n",
      "Epoch 00018 | Loss 1.2294 | Test Acc 0.4480 | Time(s) 1.3286\n",
      "Epoch 00019 | Loss 1.2240 | Test Acc 0.4494 | Time(s) 1.3184\n",
      "Epoch 00020 | Loss 1.2191 | Test Acc 0.4527 | Time(s) 1.3123\n",
      "Epoch 00021 | Loss 1.2139 | Test Acc 0.4523 | Time(s) 1.3037\n",
      "Epoch 00022 | Loss 1.2095 | Test Acc 0.4545 | Time(s) 1.3366\n",
      "Epoch 00023 | Loss 1.2084 | Test Acc 0.4552 | Time(s) 1.3238\n",
      "Epoch 00024 | Loss 1.2080 | Test Acc 0.4545 | Time(s) 1.3090\n",
      "Epoch 00025 | Loss 1.2059 | Test Acc 0.4559 | Time(s) 1.3019\n",
      "Epoch 00026 | Loss 1.2024 | Test Acc 0.4556 | Time(s) 1.2988\n",
      "Epoch 00027 | Loss 1.1991 | Test Acc 0.4552 | Time(s) 1.2953\n",
      "Epoch 00028 | Loss 1.1970 | Test Acc 0.4559 | Time(s) 1.2944\n",
      "Epoch 00029 | Loss 1.1953 | Test Acc 0.4545 | Time(s) 1.2825\n",
      "Epoch 00030 | Loss 1.1930 | Test Acc 0.4548 | Time(s) 1.2805\n",
      "Epoch 00031 | Loss 1.1904 | Test Acc 0.4559 | Time(s) 1.2734\n",
      "Epoch 00032 | Loss 1.1890 | Test Acc 0.4552 | Time(s) 1.2699\n",
      "Epoch 00033 | Loss 1.1880 | Test Acc 0.4548 | Time(s) 1.2726\n",
      "Epoch 00034 | Loss 1.1864 | Test Acc 0.4538 | Time(s) 1.2690\n",
      "Epoch 00035 | Loss 1.1852 | Test Acc 0.4523 | Time(s) 1.2711\n",
      "Epoch 00036 | Loss 1.1836 | Test Acc 0.4527 | Time(s) 1.2800\n",
      "Epoch 00037 | Loss 1.1817 | Test Acc 0.4538 | Time(s) 1.2886\n",
      "Epoch 00038 | Loss 1.1802 | Test Acc 0.4541 | Time(s) 1.2916\n",
      "Epoch 00039 | Loss 1.1790 | Test Acc 0.4538 | Time(s) 1.3006\n",
      "Epoch 00040 | Loss 1.1776 | Test Acc 0.4538 | Time(s) 1.2943\n",
      "Epoch 00041 | Loss 1.1763 | Test Acc 0.4559 | Time(s) 1.2976\n",
      "Epoch 00042 | Loss 1.1746 | Test Acc 0.4556 | Time(s) 1.2978\n",
      "Epoch 00043 | Loss 1.1727 | Test Acc 0.4559 | Time(s) 1.2923\n",
      "Epoch 00044 | Loss 1.1711 | Test Acc 0.4556 | Time(s) 1.2894\n",
      "Epoch 00045 | Loss 1.1696 | Test Acc 0.4552 | Time(s) 1.2872\n",
      "Epoch 00046 | Loss 1.1684 | Test Acc 0.4548 | Time(s) 1.2883\n",
      "Epoch 00047 | Loss 1.1672 | Test Acc 0.4556 | Time(s) 1.2854\n",
      "Epoch 00048 | Loss 1.1658 | Test Acc 0.4559 | Time(s) 1.2811\n",
      "Epoch 00049 | Loss 1.1644 | Test Acc 0.4552 | Time(s) 1.2767\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "# g, features, labels, train_mask, test_mask = load_cora_data()\n",
    "optimizer = th.optim.Adam(net.parameters(), lr=1e-3)\n",
    "dur = []  # seconds per training step, recorded from epoch 3 onwards\n",
    "# net=net.to(device)\n",
    "# g=g.to(device)\n",
    "# features=features.to(device)\n",
    "# train_mask=train_mask.to(device)\n",
    "# test_mask=test_mask.to(device)\n",
    "for epoch in range(50):\n",
    "    # The first three epochs are excluded from the timing statistics.\n",
    "    if epoch >=3:\n",
    "        t0 = time.time()\n",
    "\n",
    "    net.train()\n",
    "#     logits = net(features) #GAT\n",
    "    logits=net(g,features) #GCN\n",
    "    logp = F.log_softmax(logits, 1)\n",
    "    logp=logp.cpu()\n",
    "    # The loss is computed on the training nodes only.\n",
    "    loss = F.nll_loss(logp[train_mask], labels[train_mask])\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if epoch >=3:\n",
    "        dur.append(time.time() - t0)\n",
    "\n",
    "    # Accuracy is measured on the validation split (val_mask), so it is\n",
    "    # reported as validation accuracy rather than test accuracy.\n",
    "    acc = evaluate(net, g, features, labels, val_mask)\n",
    "    # np.mean([]) emits a RuntimeWarning; report nan explicitly until timings exist.\n",
    "    mean_dur = np.mean(dur) if dur else float(\"nan\")\n",
    "    print(\"Epoch {:05d} | Loss {:.4f} | Val Acc {:.4f} | Time(s) {:.4f}\".format(\n",
    "            epoch, loss.item(), acc, mean_dur))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## 预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/huangweilin/anaconda3/envs/fjw/lib/python3.6/site-packages/ipykernel_launcher.py:5: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
      "  \"\"\"\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f9e6d46e89a64296ac61eaf612822323",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=12782), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Ids of the nodes flagged by test_mask, in ascending order. np.flatnonzero keeps\n",
    "# node id 0, which an `id > 0` filter would drop and thereby misalign ids and predictions.\n",
    "test_mask_array=test_mask.to(torch.long).detach().numpy()\n",
    "test_ids=np.flatnonzero(test_mask_array)\n",
    "# predict() returns one label per masked node, in the same ascending-id order.\n",
    "test_pred_labels=predict(net,g,features,labels,test_mask)\n",
    "assert len(test_ids) == len(test_pred_labels), \"ids and predictions are misaligned\"\n",
    "test_submit=pd.read_csv(\"./text_classify/data/sample.csv\")\n",
    "# Fail loudly if the submission template lacks any of the test ids.\n",
    "missing_ids=np.setdiff1d(test_ids, test_submit[\"id\"].values)\n",
    "if len(missing_ids) > 0:\n",
    "    raise ValueError(\"sample.csv is missing %d test ids\" % len(missing_ids))\n",
    "# Vectorised id -> label lookup instead of one DataFrame.query per test node.\n",
    "pred_by_id=pd.Series(test_pred_labels, index=test_ids)\n",
    "predicted=test_submit[\"id\"].map(pred_by_id)\n",
    "has_pred=predicted.notna()\n",
    "# Rows whose id is not a test node keep the label already present in sample.csv.\n",
    "test_submit.loc[has_pred,'label']=predicted[has_pred].astype(int)\n",
    "test_submit.to_csv(\"./text_classify/data/bert_gcn.csv\",index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fjw",
   "language": "python",
   "name": "fjw"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
